library("reticulate")
## Warning: package 'reticulate' was built under R version 4.0.5
library("knitr")
library("Hmisc")
## Loading required package: lattice
## Loading required package: survival
## Loading required package: Formula
## Loading required package: ggplot2
## 
## Attaching package: 'Hmisc'
## The following objects are masked from 'package:base':
## 
##     format.pval, units
library("DescTools")
## 
## Attaching package: 'DescTools'
## The following objects are masked from 'package:Hmisc':
## 
##     %nin%, Label, Mean, Quantile
library("stringr")
library("egg")
## Loading required package: gridExtra
library("tidyverse")
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ tibble  3.1.6     ✓ purrr   0.3.4
## ✓ tidyr   1.1.4     ✓ dplyr   1.0.7
## ✓ readr   2.1.1     ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::combine()   masks gridExtra::combine()
## x dplyr::filter()    masks stats::filter()
## x dplyr::lag()       masks stats::lag()
## x dplyr::src()       masks Hmisc::src()
## x dplyr::summarize() masks Hmisc::summarize()
# Global plot theme: classic theme with a 24 pt base text size.
theme_set(theme_classic() + 
    theme(text = element_text(size = 24)))

# knitr chunk display options: no prefix on printed output, and hold text
# output and figures.
opts_chunk$set(comment = "",
               results = "hold",
               fig.show = "hold")

# Suppress the summarise() grouping message.
# FALSE is used instead of `F`: `F` is an ordinary variable that can be
# reassigned, whereas FALSE is a reserved constant.
options(dplyr.summarise.inform = FALSE)

Load Data

# Activate the "plinko" conda environment and read the pickled dataset
# through pandas (via reticulate).
use_condaenv("plinko")
pd = import("pandas")
df_data = pd$read_pickle("../../data/full_dataset_vision_corrected.xz")

# Filter the dataset for the analysis of a specific experiment/participant
# set: trials 305 and 309 are dropped.
df_filtered_data = filter(df_data, !trial %in% c(305, 309))

Segment judgment, rt, and eye-data

# Unique participant / trial / response combinations.
df_data_judge = unique(
  select(df_filtered_data, participant, trial, response)
)

# Response time per participant and trial: last value of `t` minus the first.
# log_rt is set to 0 when rt is exactly 0, so log(0) = -Inf never reaches
# later averages.
df_data_rt = df_filtered_data %>%
  group_by(participant, trial) %>%
  summarise(rt = t[length(t)] - t[1]) %>%
  mutate(log_rt = ifelse(rt != 0, log(rt), 0))

Compute Judgment Means

# For participants 1 to 15: per trial, the share of responses that went to
# each of the three holes, reshaped to long format (one row per trial x hole).
df_data_mean_judge_train = df_data_judge %>%
  filter(participant %in% 1:15) %>%
  group_by(trial) %>%
  summarise(
    hole1 = sum(response == 1) / n(),
    hole2 = sum(response == 2) / n(),
    hole3 = sum(response == 3) / n()
  ) %>%
  pivot_longer(
    cols = hole1:hole3,
    names_to = "hole",
    values_to = "human_mean"
  )

Compute RT Means

# Per-trial mean raw and log response time for participants 1 to 15.
# df_data_rt already carries log_rt (computed with the rt == 0 guard), so it
# is reused here instead of being recomputed with the identical formula.
df_data_mean_rt_train = df_data_rt %>% 
  filter(participant %in% seq(1, 15)) %>% 
  group_by(trial) %>% 
  summarise(mean_rt = mean(rt),
            mean_log_rt = mean(log_rt))

Results Visualization

# Histogram of response times (one value per participant x trial).
# bins is given explicitly: 30 is the value stat_bin() was falling back to,
# so the figure is unchanged but the "Pick better value" message is no longer
# emitted. This matches the explicit bins of the bandit looks histogram.
ggplot(data = df_data_rt, mapping = aes(x = rt)) +
  geom_histogram(bins = 30, fill = "grey", color = "black") +
  ggtitle("Participant Response Times") +
  xlab("Response Time (ms)") +
  theme(plot.title = element_text(hjust = 0.5))
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

Bandit Model

Top parameters

# Print the parameter columns of the first row of the bandit performance
# table (df_bandit_performance is created outside the code shown here).
select(
  head(df_bandit_performance, 1),
  thresholds, tradeoffs, bws, sample_weights
)
# A tibble: 1 × 4
  thresholds tradeoffs   bws sample_weights
       <dbl>     <dbl> <dbl>          <dbl>
1          1       0.1    40            500
# Load the bandit model output for the parameter setting printed above and
# drop the X column.
df_model_judge_rt = read.csv(
  "../python/model/model_performance/grid_judgment_rt/bandit_runs_30_threshold_1.0_tradeoff_0.1_sample_weight_500_bw_40.0_look_probs_0.5_0.8_0.02_0.5_noise_params_0.0_10.0_0.2_0.8_0.2_heuristic_prior_trial_0_150.csv"
) %>%
  select(-X)

Judgments

# Model judgment distribution: for every trial, the share of runs that
# selected each judgment value.
#
# The denominator has to be the number of runs of the *trial*. It is therefore
# computed per trial before grouping by judgment. Taking max(run) inside a
# (trial, judgment) group only sees the runs that produced that judgment, so a
# judgment that did not occur in the last run got too small a denominator and
# an inflated share.
df_model_mean_judge = df_model_judge_rt %>% 
  # judgment is shifted up by one (presumably 0-based in the csv, to line up
  # with responses 1-3 -- confirm) and turned into a factor
  mutate(judgment = judgment + 1,
         judgment = factor(judgment)) %>%
  group_by(trial) %>% 
  # same "max(run) + 1" run count as before (assumes run ids start at 0),
  # but taken over all rows of the trial
  mutate(n_runs = max(run) + 1) %>% 
  group_by(trial, judgment) %>% 
  summarise(model_mean = n() / n_runs[1]) %>% 
  ungroup() %>% 
  # trial x judgment combinations that never occurred get a share of 0
  complete(trial, judgment,
           fill = list(model_mean = 0))

# Human judgments per trial x hole, as a percentage of participants, with a
# bootstrapped confidence interval from smean.cl.boot().
# NOTE(review): no seed is set before the bootstrap, so lower/upper will vary
# slightly between runs.
df_data_mean_judge_full = df_data_judge %>% 
  # one 0/1 indicator column per hole
  mutate(hole1 = as.numeric(response == 1),
         hole2 = as.numeric(response == 2),
         hole3 = as.numeric(response == 3)) %>% 
  select(-response) %>% 
  pivot_longer(c(hole1, hole2, hole3),
               names_to = "hole",
               values_to = "response") %>% 
  # 0/1 -> 0/100 so that the group mean is a percentage
  mutate(response = response*100) %>% 
  group_by(trial, hole) %>% 
  do(data.frame(rbind(smean.cl.boot(.$response)))) %>% 
  # drop the (trial, hole) grouping left behind by do(): the result is a plain
  # per-row table, and `hole` is rewritten right below
  ungroup() %>% 
  rename(human_mean = Mean,
         lower = Lower,
         upper = Upper)

# Same table keyed like the model output: "hole1" -> factor "1", column
# renamed to `judgment`.
df_human_mean_judge = df_data_mean_judge_full %>%
  mutate(hole = as.factor(str_sub(hole, -1, -1))) %>%
  rename(judgment = hole)

# Model predictions joined with the human judgment means, one row per
# trial x judgment.
df_to_show = left_join(df_model_mean_judge,
                       df_human_mean_judge, 
                       by=c("trial", "judgment")) %>% 
  mutate(model = "Bandit") 

# Correlation is scale-invariant, so the two columns can be used as they are.
model_cor = round(cor(df_to_show$model_mean, df_to_show$human_mean), digits=2)
# model_mean is a proportion in [0, 1] while human_mean is a percentage in
# [0, 100] (responses are multiplied by 100 upstream; the plot's identity
# line has slope 100). The prediction is put on the percentage scale before
# taking the RMSE, otherwise the error mostly reflects the unit mismatch.
model_rmse = round(RMSE(df_to_show$model_mean * 100, df_to_show$human_mean), digits=2)

# Bandit judgments: model_mean (x) against human_mean (y) with the human
# confidence interval as a vertical range. The dotted line has slope 100,
# i.e. it maps a 0-1 x value onto a 0-100 y value. r and rmse are printed in
# the top-left corner.
ggplot(df_to_show, aes(x = model_mean, y = human_mean)) +
  geom_abline(intercept = 0, slope = 100, linetype = "dotted") +
  geom_linerange(aes(ymin = lower, ymax = upper), alpha = 0.2) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm", formula = y ~ x) +
  facet_grid(~ model) +
  annotate("text",
           x = 0.0, y = 100, hjust = 0,
           label = paste("r: ", model_cor)) +
  annotate("text",
           x = 0.0, y = 95, hjust = 0,
           label = paste("rmse: ", model_rmse)) +
  labs(x = "Model Prediction",
       y = "Participant Selection %") +
  theme(plot.title = element_text(size = 20, hjust = 0.5),
        axis.title = element_text(size = 16),
        axis.text = element_text(size = 10))

# Writes the plot drawn above (ggsave() defaults to the last plot).
ggsave("figures/bandit_judgment.pdf", height = 4, width = 5)

# Keep the bandit judgment comparison for the combined figure; the model
# column is renamed to `prediction`, the name shared with the other model.
df_bandit_judge = rename(
  mutate(df_to_show, model = "Bandit"),
  prediction = model_mean
)

Response Time

# Model "time" per row: simulations plus looks. Its log is taken only where
# the value is non-zero; a zero stays zero.
# Both measures are then averaged within each trial.
df_model_mean_rt = df_model_judge_rt %>%
  mutate(time_measure = num_sims + num_looks) %>%
  mutate(log_time = ifelse(time_measure != 0, log(time_measure), time_measure)) %>%
  group_by(trial) %>%
  summarise(
    mean_time = mean(time_measure),
    mean_log_time = mean(log_time)
  )

# Per-trial mean raw and log response time across participants.
# The precomputed log_rt column is averaged rather than log(rt): log_rt maps
# rt == 0 to 0, whereas log(0) is -Inf and would turn the whole trial mean
# into -Inf. When no rt is zero the result is identical.
df_data_mean_rt = df_data_rt %>%
  group_by(trial) %>% 
  summarise(mean_rt = mean(rt),
            mean_log_rt = mean(log_rt))

# Model time measures joined with the human response-time means by trial.
df_to_show = left_join(df_model_mean_rt,
                       df_data_mean_rt,
                       by = c("trial"))

model_cor = round(cor(df_to_show$mean_time, df_to_show$mean_rt), digits=2)
model_rmse = round(RMSE(df_to_show$mean_time, df_to_show$mean_rt), digits=2)

# Bandit response time on the raw scale: mean_time (x) against mean_rt (y),
# with r and rmse printed on the panel.
# Axis labels were corrected to match the plotted columns: y is mean_rt, not
# its log, and x is the mean of num_sims + num_looks, not looks alone
# ("Actions" is the term the log-scale figure uses for the same measure).
ggplot(data = df_to_show, mapping = aes(x = mean_time, y = mean_rt)) +
  geom_point(alpha = 0.7,
             shape=16) +
  geom_smooth(method = "lm",
              formula = y ~ x) +
  # geom_label(mapping = aes(label = trial)) +
  ggtitle("Bandit Response Time") +
  xlab("Model Mean Actions Across Runs") +
  ylab("Participant Mean Response Time") +
  annotate("text",
           label = paste("r =", model_cor),
           x=20,
           y=2500,
           hjust=0) +
  annotate("text",
           label = paste("rmse =", model_rmse),
           x=20,
           y=2200,
           hjust=0) +
  theme(plot.title = element_text(size=20, hjust=0.5),
        axis.title = element_text(size=14),
        axis.text = element_text(size=12))

# Writes the plot drawn above (ggsave() defaults to the last plot).
ggsave("figures/bandit_rt.png", height = 4, width = 5)

# Redefine df_data_mean_rt: per-trial mean of log_rt with the bootstrapped
# interval returned by smean.cl.boot() (columns Mean / Lower / Upper).
df_data_mean_rt = df_data_rt %>%
  group_by(trial) %>%
  do(data.frame(rbind(smean.cl.boot(.$log_rt)))) %>%
  rename(
    mean_log_rt = Mean,
    lower = Lower,
    upper = Upper
  )

# Model time measures next to the human log-RT summary, tagged with the
# model name used for faceting.
df_to_show = df_model_mean_rt %>%
  left_join(df_data_mean_rt, by = c("trial")) %>%
  mutate(model = "Bandit")

model_cor = round(
  cor(df_to_show$mean_log_time, df_to_show$mean_log_rt),
  digits = 2
)
model_rmse = round(
  RMSE(df_to_show$mean_log_time, df_to_show$mean_log_rt),
  digits = 2
)

# Candidate axis values; not referenced by any code visible in this script.
xvals = c(0.5, 1.0, 1.5, 2.0, 2.5)
yvals = c(7.0, 7.5, 8.0, 8.5)

# Bandit response time on the log scale: mean_log_time (x) against
# mean_log_rt (y) with the bootstrapped interval as a vertical range and a
# linear fit. r and rmse are printed in the top-left corner.
ggplot(df_to_show, aes(x = mean_log_time, y = mean_log_rt)) +
  geom_linerange(aes(ymin = lower, ymax = upper), alpha = 0.15) +
  geom_point(alpha = 0.7, shape = 16) +
  geom_smooth(method = "lm", formula = y ~ x) +
  facet_grid(~ model) +
  annotate("text",
           x = 0.5, y = 8.8, hjust = 0, size = 6,
           label = paste("r: ", model_cor)) +
  annotate("text",
           x = 0.5, y = 8.68, hjust = 0, size = 6,
           label = paste("rmse: ", model_rmse)) +
  labs(x = "Model Mean log Actions",
       y = "Mean log Response Time") +
  theme(plot.title = element_text(size = 20, hjust = 0.5),
        axis.title = element_text(size = 24),
        axis.text = element_text(size = 18),
        plot.margin = margin(10, 0, 0, 10))

# Writes the plot drawn above (ggsave() defaults to the last plot).
ggsave("figures/bandit_log_rt.pdf", height = 4, width = 5)

# Keep the bandit RT comparison for the combined figure, with the time
# columns renamed to the names shared with the other model.
df_bandit_rt = rename(
  df_to_show,
  time_measure = mean_time,
  log_time = mean_log_time
)

# Distribution of num_looks over all rows of the bandit model output.
ggplot(df_model_judge_rt, aes(x = num_looks)) +
  geom_histogram(bins = 30, fill = "grey", color = "black") +
  labs(title = "Bandit Looks Histogram",
       x = "Number of Looks") +
  theme(plot.title = element_text(hjust = 0.5))

Fixed Sample Model

Top parameters

# Print the parameter columns of the first row of the fixed-sample
# performance table (df_fixed_performance is created outside the code shown
# here).
select(
  head(df_fixed_performance, 1),
  num_samples, bws
)
# A tibble: 1 × 2
  num_samples   bws
        <dbl> <dbl>
1          80    12
# Load the fixed-sample model output for the parameter setting printed above
# and drop the X column.
df_fixed_sample_judge_rt = read.csv(
  "../python/model/model_performance/grid_judgment_rt/fixed_sample_num_samples_80_bw_12.0_look_probs_0.5_0.8_0.02_0.5_noise_params_0.0_10.0_0.2_0.8_0.2_trial_0_150.csv"
) %>%
  select(-X)

Judgments

# Fixed-sample predictions in long format: one row per trial x hole.
df_fixed_sample_long = df_fixed_sample_judge_rt %>% 
  select(trial, hole1, hole2, hole3) %>% 
  pivot_longer(c(hole1, hole2, hole3),
               names_to = "hole",
               values_to = "prediction")

# Predictions joined with the human judgment percentages and their intervals.
df_to_show = df_fixed_sample_long %>% 
  left_join(df_data_mean_judge_full, by = c("trial", "hole")) %>% 
  mutate(model = "Fixed Sample")

# Correlation is scale-invariant, so the two columns can be used as they are.
fixed_sample_cor = round(cor(df_to_show$prediction, df_to_show$human_mean), digits = 2)
# human_mean is a percentage (0-100). prediction appears to be on a 0-1 scale
# (the identity line of the plot is drawn with slope 100 -- confirm), so it is
# converted to percent before the RMSE; otherwise the error mostly reflects
# the unit mismatch.
fixed_sample_rmse = round(RMSE(df_to_show$prediction * 100, df_to_show$human_mean), digits = 2)

# Fixed-sample judgments: prediction (x) against human_mean (y) with the
# human interval as a vertical range, a dotted slope-100 reference line and a
# linear fit. r and rmse are printed in the top-left corner.
ggplot(df_to_show, aes(x = prediction, y = human_mean)) +
  geom_abline(intercept = 0, slope = 100, linetype = "dotted") +
  geom_linerange(aes(ymin = lower, ymax = upper), alpha = 0.2) +
  geom_point(alpha = 0.5, shape = 16) +
  geom_smooth(method = "lm", formula = y ~ x) +
  annotate("text",
           x = 0.0, y = 100, hjust = 0,
           label = paste("r:", fixed_sample_cor)) +
  annotate("text",
           x = 0.0, y = 95, hjust = 0,
           label = paste("rmse:", fixed_sample_rmse)) +
  facet_grid(~ model) +
  labs(x = "Model Prediction",
       y = "Participant Mean Judgment") +
  theme(plot.title = element_text(size = 20, hjust = 0.5),
        axis.title = element_text(size = 16),
        axis.text = element_text(size = 10))

# Writes the plot drawn above (ggsave() defaults to the last plot).
ggsave("figures/fixed_sample_judgments.pdf", height = 4, width = 5)

# Keep the fixed-sample judgment comparison for the combined figure: the
# trailing digit of `hole` becomes the factor `judgment`, the model is
# relabelled "Uniform Sampler", and `hole` is dropped.
df_fixed_judge = select(
  mutate(df_to_show,
         judgment = as.factor(str_sub(hole, -1, -1)),
         model = "Uniform Sampler"),
  -hole
)

Response Time

# Fixed-sample "time" per trial (simulations plus looks) and its log, joined
# with the human log-RT summary in df_data_mean_rt.
df_to_show = df_fixed_sample_judge_rt %>%
  select(trial, num_sims, num_looks) %>%
  mutate(time_measure = num_sims + num_looks) %>%
  mutate(log_time = log(time_measure)) %>%
  left_join(df_data_mean_rt, by = "trial") %>%
  mutate(model = "Uniform Sampler")

fixed_sample_rt_cor = round(
  cor(df_to_show$log_time, df_to_show$mean_log_rt),
  digits = 2
)
fixed_sample_rt_rmse = round(
  RMSE(df_to_show$log_time, df_to_show$mean_log_rt),
  digits = 2
)

# Fixed-sample response time: log_time (x) against mean_log_rt (y) with the
# bootstrapped interval as a vertical range and a linear fit. r and rmse are
# printed at x = 6.3.
ggplot(df_to_show, aes(x = log_time, y = mean_log_rt)) +
  geom_linerange(aes(ymin = lower, ymax = upper), alpha = 0.15) +
  geom_point(alpha = 0.5, shape = 16) +
  geom_smooth(method = "lm", formula = y ~ x) +
  facet_grid(~ model) +
  annotate("text",
           x = 6.3, y = 8.8, size = 6, hjust = 0,
           label = paste("r:", fixed_sample_rt_cor)) +
  annotate("text",
           x = 6.3, y = 8.68, size = 6, hjust = 0,
           label = paste("rmse:", fixed_sample_rt_rmse)) +
  labs(x = "Model Prediction",
       y = "Participant Mean log \n Response Time") +
  theme(plot.title = element_text(size = 20, hjust = 0.5),
        axis.title = element_text(size = 24),
        axis.text = element_text(size = 18),
        plot.margin = margin(10, 0, 0, 10))

# Writes the plot drawn above (ggsave() defaults to the last plot).
ggsave("figures/fixed_sample_rt.pdf", height = 4, width = 5)

# Keep the fixed-sample RT comparison for the combined figure, without the
# raw look/simulation counts.
df_fixed_rt = df_to_show %>% 
  select(-c(num_looks, num_sims))

# Distribution of num_looks in the fixed-sample model output.
# bins is given explicitly: 30 is the value stat_bin() was falling back to,
# so the figure is unchanged but the "Pick better value" message is no longer
# emitted. This matches the bandit looks histogram.
ggplot(df_fixed_sample_judge_rt, mapping = aes(x = num_looks))+ 
  geom_histogram(bins = 30, fill = "grey", color = "black") +
  ggtitle("Fixed Sample Looks Histogram") +
  xlab("Number of Looks") +
  theme(plot.title = element_text(hjust=0.5))
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

# Cogsci Figures

Judgments

# Judgment comparisons of both models stacked into one table.
df_to_show = rbind(df_bandit_judge,
                   df_fixed_judge)

# Per-model fit statistics shown on the combined figure.
# prediction is on a 0-1 scale while human_mean is a percentage (the figure's
# identity line has slope 100), so the prediction is converted to percent
# before the RMSE. The correlation is unaffected by the scale.
df_sum_stat = df_to_show %>% 
  group_by(model) %>% 
  summarise(r = round(cor(prediction, human_mean), digits = 2),
            rmse = round(RMSE(prediction * 100, human_mean), digits = 2))

# Combined judgment figure: one panel per model, prediction (x) against
# human_mean (y) with the human interval, a dotted slope-100 reference line
# and a linear fit. The per-model r and rmse from df_sum_stat are printed in
# each panel.
ggplot(df_to_show, mapping = aes(x = prediction, y = human_mean)) +
  geom_abline(slope = 100,
              intercept = 0,
              linetype = "dotted") +
  geom_linerange(mapping = aes(ymin = lower,
                               ymax = upper),
                 alpha = 0.2) +
  geom_point(alpha = 0.5) +
  # formula given explicitly, as in the single-model figures; it is the
  # default, so the fit is unchanged but the "using formula" message is gone
  geom_smooth(method = "lm",
              formula = y ~ x) + 
  geom_text(data = df_sum_stat,
            x = 0.0,
            y = 100,
            size = 6,
            hjust = 0,
            mapping = aes(label = paste0("r: ", r))) +
  geom_text(data = df_sum_stat,
            x = 0.0,
            y = 93,
            size = 6,
            hjust = 0,
            mapping = aes(label = paste0("rmse: ", rmse))) +
  facet_wrap(~ model) +
  xlab("Model Prediction") + 
  ylab("Participant % Selection") +
  theme(plot.title = element_text(size=20, hjust=0.5),
        axis.title = element_text(size=24),
        axis.text = element_text(size=18),
        panel.spacing = unit(2, "lines"))
`geom_smooth()` using formula 'y ~ x'
# Write the combined judgment figure; ggsave() defaults to the last plot.
ggsave("figures/model_judgment.pdf", width = 10, height = 4)
`geom_smooth()` using formula 'y ~ x'

RT

# Response-time comparisons of both models stacked into one table.
df_to_show = rbind(df_bandit_rt, df_fixed_rt)

# Per-model correlation and RMSE between the model's log time and the human
# mean log response time, rounded to two digits.
df_sum_stat = summarise(
  group_by(df_to_show, model),
  r = round(cor(log_time, mean_log_rt), digits = 2),
  rmse = round(RMSE(log_time, mean_log_rt), digits = 2)
)


# Combined response-time figure: one panel per model with its own x range,
# log_time (x) against mean_log_rt (y), the bootstrapped interval and a
# linear fit. The per-model r and rmse from df_sum_stat are printed in each
# panel.
ggplot(df_to_show, mapping = aes(x = log_time,
                                 y = mean_log_rt)) +
  geom_linerange(mapping = aes(ymin = lower,
                               ymax = upper),
                 alpha = 0.3) +
  geom_point(alpha = 0.7,
             shape = 16) +
  # formula given explicitly, as in the single-model figures; it is the
  # default, so the fit is unchanged but the "using formula" message is gone
  geom_smooth(method = "lm",
              formula = y ~ x) +
  # The labels are anchored to the left panel edge (x = -Inf) rather than to
  # a fixed x value. With scales = "free_x" the two panels cover different x
  # ranges (the single-model figures place their labels at x = 0.5 and
  # x = 6.3), and a constant x does not extend a panel's range, so a label
  # fixed at 0.5 falls outside the Uniform Sampler panel. The slightly
  # negative hjust keeps the text off the axis line.
  geom_text(data = df_sum_stat,
            x = -Inf,
            y = 8.7,
            hjust = -0.1,
            size = 4,
            mapping = aes(label = paste("r:", r))) +
  geom_text(data = df_sum_stat,
            x = -Inf,
            y = 8.6,
            hjust = -0.1,
            size = 4,
            mapping = aes(label = paste("rmse:", rmse))) +
  facet_wrap(~ model,
             scales = "free_x") +
  xlab("Model Prediction") +
  ylab("Mean log Response Time") +
  theme(plot.title = element_text(size=20, hjust=0.5),
        axis.title = element_text(size=14),
        axis.text = element_text(size=12),
        panel.spacing = unit(2, "lines"))
`geom_smooth()` using formula 'y ~ x'

EMD

# Earth mover's distance per trial for each model, one table per model with
# the columns trial / distance / model.
# trial is converted to a factor in all three tables. Previously the bandit
# table kept the type read from the csv while the other two used factors, and
# rbind() reconciled the mismatch only by silently coercing the combined
# column to character.
df_emd_bandit = read.csv("../python/model/model_performance/emd/top_bandit.csv") %>% 
  select(trial, distance) %>% 
  mutate(trial = factor(trial),
         model = "Bandit")

df_emd_fixed_sample = read.csv("../python/model/model_performance/emd/top_fixed_sample.csv") %>%
  select(trial, distance) %>%
  mutate(trial = factor(trial),
         model = "Uniform Sampler")

df_emd_baseline = read.csv("../python/model/model_performance/emd/emd_baseline.csv") %>% 
  select(-X) %>% 
  mutate(trial = factor(trial),
         model = "Baseline")
# Trials to draw in the highlight colour (none at the moment).
to_highlight = c()

# Fixed seed so the horizontal jitter below is reproducible.
set.seed(1)

# All three EMD tables stacked. `model` is replaced by its numeric position
# (Bandit = 1, Uniform Sampler = 2, Baseline = 3) so it can be placed on a
# continuous x axis, and model_jitter adds uniform noise in [-0.15, 0.15].
df_to_show = rbind(df_emd_bandit, df_emd_fixed_sample, df_emd_baseline) %>%
  mutate(
    model = as.numeric(match(model, c("Bandit", "Uniform Sampler", "Baseline"))),
    highlight = trial %in% to_highlight,
    model_jitter = model + runif(n = n(), min = -0.15, max = 0.15)
  )

# EMD comparison: jittered per-trial points for each model, faint lines
# connecting the same trial across models, and the bootstrapped mean with its
# interval in red. Point colour follows `highlight`; the legend is hidden.
ggplot(df_to_show, aes(x = model, y = distance, color = highlight)) +
  geom_line(aes(x = model_jitter, group = trial), alpha = 0.05) +
  geom_point(aes(x = model_jitter), alpha = 0.5, shape = 16, size = 3) +
  stat_summary(fun.data = "mean_cl_boot", color = "red", size = 0.8) +
  scale_x_continuous(
    breaks = c(1, 2, 3),
    labels = c("Bandit", "Uniform Sampler", "Baseline")
  ) +
  scale_color_manual(values = c("black", "magenta3")) +
  labs(y = "Earth Mover's Distance") +
  theme(legend.title = element_blank(),
        legend.position = "none",
        axis.title.y = element_text(size = 24),
        axis.title.x = element_blank(),
        axis.text = element_text(size = 20))

# Writes the plot drawn above (ggsave() defaults to the last plot).
ggsave("figures/emd_comparison.pdf", height = 5, width = 10)

# EMD of all three models in one table.
df_emd = rbind(df_emd_bandit, df_emd_fixed_sample, df_emd_baseline)

# Per model: mean distance with the interval returned by smean.cl.boot(),
# rounded to two digits.
df_emd %>%
  group_by(model) %>%
  do(data.frame(rbind(round(smean.cl.boot(.$distance), digits = 2))))
# A tibble: 3 × 4
# Groups:   model [3]
  model            Mean Lower Upper
  <chr>           <dbl> <dbl> <dbl>
1 Bandit           51.2  48.8  53.7
2 Baseline        118.  116.  121. 
3 Uniform Sampler  77.1  74.6  79.7
# Redefine df_emd with only the two models being contrasted.
df_emd = rbind(df_emd_bandit, df_emd_fixed_sample)

# One row per trial with the two models' distances side by side, sorted by
# how much larger the uniform sampler's distance is than the bandit's.
df_emd %>%
  mutate(model = ifelse(model != "Uniform Sampler", "bandit", "uniform_sampler")) %>%
  pivot_wider(
    names_from = model,
    values_from = distance
  ) %>%
  mutate(diff = uniform_sampler - bandit) %>%
  arrange(desc(diff))
# A tibble: 150 × 4
   trial bandit uniform_sampler  diff
   <chr>  <dbl>           <dbl> <dbl>
 1 190     32.0           112.   79.5
 2 67      37.1           116.   79.0
 3 72      33.9           109.   74.9
 4 12      49.5           113.   63.7
 5 158     35.3            96.4  61.1
 6 114     49.6           108.   58.8
 7 46      25.5            84.1  58.5
 8 223     36.5            95.0  58.5
 9 254     34.3            91.8  57.6
10 20      35.5            90.3  54.8
# … with 140 more rows
# Distribution of EMD values, one panel per model (df_to_show codes the
# models as 1 = Bandit, 2 = Uniform Sampler, 3 = Baseline).
# bins is given explicitly: 30 is the value stat_bin() was falling back to,
# so the figure is unchanged but the "Pick better value" message is no longer
# emitted.
ggplot(df_to_show, mapping = aes(x = distance, fill = model)) +
  geom_histogram(bins = 30, color = "black") +
  facet_wrap(~model, nrow = 3)
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.